/*!
 * Copyright (c) 2016 by Contributors
 * \file spatial_transformer.cc
 * \brief
 * \author Wei Wu
*/

#include "./spatial_transformer-inl.h"

namespace mshadow {
template<typename DType>
inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
                                    const Tensor<cpu, 4, DType> &input,
                                    const Tensor<cpu, 3, DType> grid_src) {
  DType *out = output.dptr_;
  const DType *data = input.dptr_;
  const DType *grid = grid_src.dptr_;
  int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
  int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
  for (index_t n = 0; n < o_n; ++n) {
    for (index_t c = 0; c < o_c; ++c) {
      for (index_t h = 0; h < o_h; ++h) {
        for (index_t w = 0; w < o_w; ++w) {
          index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
          index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
          DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
          DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
          index_t top_left_y = std::min(i_h, std::max(0, static_cast<int>(floor(y_real))));
          index_t top_left_x = std::min(i_w, std::max(0, static_cast<int>(floor(x_real))));
          DType top_left_y_w = 1.0 - (y_real - top_left_y);
          DType top_left_x_w = 1.0 - (x_real - top_left_x);
          index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
          DType top_left_v = *(data + data_index);
          DType top_right_v = *(data + data_index + 1);
          DType bottom_left_v = *(data + data_index + i_w);
          DType bottom_right_v = *(data + data_index + i_w + 1);
          *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
                             top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
                             bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
                             bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
        }
      }
    }
  }
}

template<typename DType>
inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
                                     const Tensor<cpu, 3, DType> &grid_src_data,
                                     const Tensor<cpu, 4, DType> &output_grad,
                                     const Tensor<cpu, 4, DType> &input_data) {
  DType *g_input = input_grad.dptr_;
  DType *grid_src = grid_src_data.dptr_;
  const DType *grad = output_grad.dptr_;
  const DType *data = input_data.dptr_;
  int o_n = output_grad.size(0), o_c = output_grad.size(1),
      o_h = output_grad.size(2), o_w = output_grad.size(3);
  int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
  for (index_t n = 0; n < o_n; ++n) {
     for (index_t h = 0; h < o_h; ++h) {
        for (index_t w = 0; w < o_w; ++w) {
          DType top_left_y_gw = 0.0;
          DType top_left_x_gw = 0.0;
          index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
          DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
          DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
          index_t top_left_y = std::min(i_h, std::max(0, static_cast<int>(floor(y_real))));
          index_t top_left_x = std::min(i_w, std::max(0, static_cast<int>(floor(x_real))));
          DType top_left_y_w = 1.0 - (y_real - top_left_y);
          DType top_left_x_w = 1.0 - (x_real - top_left_x);
          for (index_t c = 0; c < o_c; ++c) {
            index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
            index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
                                 + top_left_x;
            // calc 4 vertex value in input data
            DType top_left_v = *(data + data_index);
            DType top_right_v = *(data + data_index + 1);
            DType bottom_left_v = *(data + data_index + i_w);
            DType bottom_right_v = *(data + data_index + i_w + 1);
            // calc input grad
            *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
            *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
                                           * (1.0 - top_left_x_w);
            *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
                                            * top_left_x_w;
            *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
                                                * (1.0 - top_left_x_w);
            // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
            top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
                             (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                             * top_left_x_w);
            top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
                             (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                             * top_left_y_w);
          }
          // calc grid_src grad
          *(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
          *(grid_src + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
        }
      }
    }
  }

}  // namespace mshadow

namespace mxnet {
namespace op {
template<>
Operator* CreateOp<cpu>(SpatialTransformerParam param, int dtype) {
  Operator *op = NULL;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new SpatialTransformerOp<cpu, DType>(param);
  })
  return op;
}

Operator *SpatialTransformerProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                     std::vector<int> *in_type) const {
  std::vector<TShape> out_shape, aux_shape;
  std::vector<int> out_type, aux_type;
  CHECK(InferType(in_type, &out_type, &aux_type));
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}

DMLC_REGISTER_PARAMETER(SpatialTransformerParam);

MXNET_REGISTER_OP_PROPERTY(SpatialTransformer, SpatialTransformerProp)
.add_argument("data", "Symbol", "Input data to the SpatialTransformerOp.")
.add_argument("loc", "Symbol", "localisation net, the output dim should be 6 when transform_type "
              "is affine, and the name of loc symbol should better starts with 'stn_loc', so that "
              "initialization it with iddentify tranform, or you shold initialize the weight and "
              "bias by yourself.")
.add_arguments(SpatialTransformerParam::__FIELDS__())
.describe("Apply spatial transformer to input feature map.");

}  // namespace op
}  // namespace mxnet
